{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c1084315",
   "metadata": {},
   "source": [
    "# 1-1,结构化数据建模流程范例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e707554e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "\n",
    "#打印时间\n",
    "def printbar():\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "\n",
    "#mac系统上pytorch和matplotlib在jupyter中同时跑需要更改环境变量\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74d1bbf",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "!pip install torch==1.10.0\n",
    "!pip install torchkeras==3.2.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83c3b159",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torchkeras \n",
    "print(\"torch.__version__ = \", torch.__version__)\n",
    "print(\"torchkeras.__version__ = \", torchkeras.__version__) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6afb823",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__ =  1.10.0\n",
    "torchkeras.__version__ =  3.2.3\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29bd80bf",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "<font color=\"red\">\n",
    " \n",
    "公众号 **算法美食屋** 回复关键词：**pytorch**， 获取本项目源码和所用数据集百度云盘下载链接。\n",
    "    \n",
    "</font> \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "021a78c0",
   "metadata": {},
   "source": [
    "### 一，准备数据"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22985769",
   "metadata": {},
   "source": [
    "titanic数据集的目标是根据乘客信息预测他们在Titanic号撞击冰山沉没后能否生存。\n",
    "\n",
    "结构化数据一般会使用Pandas中的DataFrame进行预处理。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af13fa86",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import pandas as pd \n",
    "import matplotlib.pyplot as plt\n",
    "import torch \n",
    "from torch import nn \n",
    "from torch.utils.data import Dataset,DataLoader,TensorDataset\n",
    "\n",
    "dftrain_raw = pd.read_csv('./eat_pytorch_datasets/titanic/train.csv')\n",
    "dftest_raw = pd.read_csv('./eat_pytorch_datasets/titanic/test.csv')\n",
    "dftrain_raw.head(10)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04d4eadb",
   "metadata": {},
   "source": [
    "![](./data/1-1-数据集展示.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7edf254",
   "metadata": {},
   "source": [
    "字段说明：\n",
    "\n",
    "* Survived:0代表死亡，1代表存活【y标签】\n",
    "* Pclass:乘客所持票类，有三种值(1,2,3) 【转换成onehot编码】\n",
    "* Name:乘客姓名 【舍去】\n",
    "* Sex:乘客性别 【转换成bool特征】\n",
    "* Age:乘客年龄(有缺失) 【数值特征，添加“年龄是否缺失”作为辅助特征】\n",
    "* SibSp:乘客兄弟姐妹/配偶的个数(整数值) 【数值特征】\n",
    "* Parch:乘客父母/孩子的个数(整数值)【数值特征】\n",
    "* Ticket:票号(字符串)【舍去】\n",
    "* Fare:乘客所持票的价格(浮点数，0-500不等) 【数值特征】\n",
    "* Cabin:乘客所在船舱(有缺失) 【添加“所在船舱是否缺失”作为辅助特征】\n",
    "* Embarked:乘客登船港口:S、C、Q(有缺失)【转换成onehot编码，四维度 S,C,Q,nan】\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a726f38",
   "metadata": {},
   "source": [
    "利用Pandas的数据可视化功能我们可以简单地进行探索性数据分析EDA（Exploratory Data Analysis）。\n",
    "\n",
    "label分布情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78e13165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'png'\n",
    "ax = dftrain_raw['Survived'].value_counts().plot(kind = 'bar',\n",
    "     figsize = (12,8),fontsize=15,rot = 0)\n",
    "ax.set_ylabel('Counts',fontsize = 15)\n",
    "ax.set_xlabel('Survived',fontsize = 15)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20641f47",
   "metadata": {},
   "source": [
    "![](./data/1-1-Label分布.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f6fa85c",
   "metadata": {},
   "source": [
    "年龄分布情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2284f17f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'png'\n",
    "ax = dftrain_raw['Age'].plot(kind = 'hist',bins = 20,color= 'purple',\n",
    "                    figsize = (12,8),fontsize=15)\n",
    "\n",
    "ax.set_ylabel('Frequency',fontsize = 15)\n",
    "ax.set_xlabel('Age',fontsize = 15)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58763f84",
   "metadata": {},
   "source": [
    "![](./data/1-1-年龄分布.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25ab6193",
   "metadata": {},
   "source": [
    "年龄和label的相关性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a0979c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'png'\n",
    "ax = dftrain_raw.query('Survived == 0')['Age'].plot(kind = 'density',\n",
    "                      figsize = (12,8),fontsize=15)\n",
    "dftrain_raw.query('Survived == 1')['Age'].plot(kind = 'density',\n",
    "                      figsize = (12,8),fontsize=15)\n",
    "ax.legend(['Survived==0','Survived==1'],fontsize = 12)\n",
    "ax.set_ylabel('Density',fontsize = 15)\n",
    "ax.set_xlabel('Age',fontsize = 15)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23dab310",
   "metadata": {},
   "source": [
    "![](./data/1-1-年龄相关性.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4ba0563",
   "metadata": {},
   "source": [
    "下面为正式的数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28dcead",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocessing(dfdata):\n",
    "\n",
    "    dfresult= pd.DataFrame()\n",
    "\n",
    "    #Pclass\n",
    "    dfPclass = pd.get_dummies(dfdata['Pclass'])\n",
    "    dfPclass.columns = ['Pclass_' +str(x) for x in dfPclass.columns ]\n",
    "    dfresult = pd.concat([dfresult,dfPclass],axis = 1)\n",
    "\n",
    "    #Sex\n",
    "    dfSex = pd.get_dummies(dfdata['Sex'])\n",
    "    dfresult = pd.concat([dfresult,dfSex],axis = 1)\n",
    "\n",
    "    #Age\n",
    "    dfresult['Age'] = dfdata['Age'].fillna(0)\n",
    "    dfresult['Age_null'] = pd.isna(dfdata['Age']).astype('int32')\n",
    "\n",
    "    #SibSp,Parch,Fare\n",
    "    dfresult['SibSp'] = dfdata['SibSp']\n",
    "    dfresult['Parch'] = dfdata['Parch']\n",
    "    dfresult['Fare'] = dfdata['Fare']\n",
    "\n",
    "    #Carbin\n",
    "    dfresult['Cabin_null'] =  pd.isna(dfdata['Cabin']).astype('int32')\n",
    "\n",
    "    #Embarked\n",
    "    dfEmbarked = pd.get_dummies(dfdata['Embarked'],dummy_na=True)\n",
    "    dfEmbarked.columns = ['Embarked_' + str(x) for x in dfEmbarked.columns]\n",
    "    dfresult = pd.concat([dfresult,dfEmbarked],axis = 1)\n",
    "\n",
    "    return(dfresult)\n",
    "\n",
    "x_train = preprocessing(dftrain_raw).values\n",
    "y_train = dftrain_raw[['Survived']].values\n",
    "\n",
    "x_test = preprocessing(dftest_raw).values\n",
    "y_test = dftest_raw[['Survived']].values\n",
    "\n",
    "print(\"x_train.shape =\", x_train.shape )\n",
    "print(\"x_test.shape =\", x_test.shape )\n",
    "\n",
    "print(\"y_train.shape =\", y_train.shape )\n",
    "print(\"y_test.shape =\", y_test.shape )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e96d2734",
   "metadata": {},
   "source": [
    "```\n",
    "x_train.shape = (712, 15)\n",
    "x_test.shape = (179, 15)\n",
    "y_train.shape = (712, 1)\n",
    "y_test.shape = (179, 1)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95d6c6d9",
   "metadata": {},
   "source": [
    "进一步使用DataLoader和TensorDataset封装成可以迭代的数据管道。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d744935",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_train = DataLoader(TensorDataset(torch.tensor(x_train).float(),torch.tensor(y_train).float()),\n",
    "                     shuffle = True, batch_size = 8)\n",
    "dl_val = DataLoader(TensorDataset(torch.tensor(x_test).float(),torch.tensor(y_test).float()),\n",
    "                     shuffle = False, batch_size = 8)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ec2fc6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试数据管道\n",
    "for features,labels in dl_train:\n",
    "    print(features,labels)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cbd018d",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[  0.0000,   0.0000,   1.0000,   0.0000,   1.0000,   0.0000,   1.0000,\n",
    "           0.0000,   0.0000,   7.8958,   1.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  1.0000,   0.0000,   0.0000,   0.0000,   1.0000,   0.0000,   1.0000,\n",
    "           0.0000,   0.0000,  30.5000,   0.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  1.0000,   0.0000,   0.0000,   1.0000,   0.0000,  31.0000,   0.0000,\n",
    "           1.0000,   0.0000, 113.2750,   0.0000,   1.0000,   0.0000,   0.0000,\n",
    "           0.0000],\n",
    "        [  1.0000,   0.0000,   0.0000,   0.0000,   1.0000,  60.0000,   0.0000,\n",
    "           0.0000,   0.0000,  26.5500,   1.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  0.0000,   0.0000,   1.0000,   0.0000,   1.0000,  28.0000,   0.0000,\n",
    "           0.0000,   0.0000,  22.5250,   1.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  0.0000,   0.0000,   1.0000,   0.0000,   1.0000,  32.0000,   0.0000,\n",
    "           0.0000,   0.0000,   8.3625,   1.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  0.0000,   1.0000,   0.0000,   1.0000,   0.0000,  28.0000,   0.0000,\n",
    "           0.0000,   0.0000,  13.0000,   1.0000,   0.0000,   0.0000,   1.0000,\n",
    "           0.0000],\n",
    "        [  1.0000,   0.0000,   0.0000,   0.0000,   1.0000,  36.0000,   0.0000,\n",
    "           0.0000,   1.0000, 512.3292,   0.0000,   1.0000,   0.0000,   0.0000,\n",
    "           0.0000]]) tensor([[0.],\n",
    "        [1.],\n",
    "        [1.],\n",
    "        [0.],\n",
    "        [0.],\n",
    "        [0.],\n",
    "        [1.],\n",
    "        [1.]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0008e20c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "25126768",
   "metadata": {},
   "source": [
    "### 二，定义模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a275e98",
   "metadata": {},
   "source": [
    "使用Pytorch通常有三种方式构建模型：使用nn.Sequential按层顺序构建模型，继承nn.Module基类构建自定义模型，继承nn.Module基类构建模型并辅助应用模型容器进行封装。\n",
    "\n",
    "此处选择使用最简单的nn.Sequential，按层顺序模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "617186ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_net():\n",
    "    net = nn.Sequential()\n",
    "    net.add_module(\"linear1\",nn.Linear(15,20))\n",
    "    net.add_module(\"relu1\",nn.ReLU())\n",
    "    net.add_module(\"linear2\",nn.Linear(20,15))\n",
    "    net.add_module(\"relu2\",nn.ReLU())\n",
    "    net.add_module(\"linear3\",nn.Linear(15,1))\n",
    "    return net\n",
    "    \n",
    "net = create_net()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdef0374",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6761227a",
   "metadata": {},
   "source": [
    "### 三，训练模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af89d5af",
   "metadata": {},
   "source": [
    "Pytorch通常需要用户编写自定义训练循环，训练循环的代码风格因人而异。\n",
    "\n",
    "有3类典型的训练循环代码风格：脚本形式训练循环，函数形式训练循环，类形式训练循环。\n",
    "\n",
    "此处介绍一种较通用的仿照Keras风格的脚本形式的训练循环。\n",
    "\n",
    "该脚本形式的训练代码与 torchkeras 库的核心代码基本一致。\n",
    "\n",
    "torchkeras详情:  https://github.com/lyhue1991/torchkeras \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87c5039d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "from torchkeras.metrics import Accuracy\n",
    "\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "    \n",
    "\n",
    "loss_fn = nn.BCEWithLogitsLoss()\n",
    "optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   \n",
    "metrics_dict = {\"acc\":Accuracy()}\n",
    "\n",
    "epochs = 20 \n",
    "ckpt_path='checkpoint.pt'\n",
    "\n",
    "#early_stopping相关设置\n",
    "monitor=\"val_acc\"\n",
    "patience=5\n",
    "mode=\"max\"\n",
    "\n",
    "history = {}\n",
    "\n",
    "for epoch in range(1, epochs+1):\n",
    "    printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "\n",
    "    # 1，train -------------------------------------------------  \n",
    "    net.train()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    \n",
    "    loop = tqdm(enumerate(dl_train), total =len(dl_train))\n",
    "    train_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    for i, batch in loop: \n",
    "        \n",
    "        features,labels = batch\n",
    "        #forward\n",
    "        preds = net(features)\n",
    "        loss = loss_fn(preds,labels)\n",
    "        \n",
    "        #backward\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {\"train_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in train_metrics_dict.items()}\n",
    "        \n",
    "        step_log = dict({\"train_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        \n",
    "        step+=1\n",
    "        if i!=len(dl_train)-1:\n",
    "            loop.set_postfix(**step_log)\n",
    "        else:\n",
    "            epoch_loss = total_loss/step\n",
    "            epoch_metrics = {\"train_\"+name:metric_fn.compute().item() \n",
    "                             for name,metric_fn in train_metrics_dict.items()}\n",
    "            epoch_log = dict({\"train_loss\":epoch_loss},**epoch_metrics)\n",
    "            loop.set_postfix(**epoch_log)\n",
    "\n",
    "            for name,metric_fn in train_metrics_dict.items():\n",
    "                metric_fn.reset()\n",
    "                \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "        \n",
    "\n",
    "    # 2，validate -------------------------------------------------\n",
    "    net.eval()\n",
    "    \n",
    "    total_loss,step = 0,0\n",
    "    loop = tqdm(enumerate(dl_val), total =len(dl_val))\n",
    "    \n",
    "    val_metrics_dict = deepcopy(metrics_dict) \n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, batch in loop: \n",
    "\n",
    "            features,labels = batch\n",
    "            \n",
    "            #forward\n",
    "            preds = net(features)\n",
    "            loss = loss_fn(preds,labels)\n",
    "\n",
    "            #metrics\n",
    "            step_metrics = {\"val_\"+name:metric_fn(preds, labels).item() \n",
    "                            for name,metric_fn in val_metrics_dict.items()}\n",
    "\n",
    "            step_log = dict({\"val_loss\":loss.item()},**step_metrics)\n",
    "\n",
    "            total_loss += loss.item()\n",
    "            step+=1\n",
    "            if i!=len(dl_val)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = (total_loss/step)\n",
    "                epoch_metrics = {\"val_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in val_metrics_dict.items()}\n",
    "                epoch_log = dict({\"val_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in val_metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "                    \n",
    "    epoch_log[\"epoch\"] = epoch           \n",
    "    for name, metric in epoch_log.items():\n",
    "        history[name] = history.get(name, []) + [metric]\n",
    "\n",
    "    # 3，early-stopping -------------------------------------------------\n",
    "    arr_scores = history[monitor]\n",
    "    best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "    if best_score_idx==len(arr_scores)-1:\n",
    "        torch.save(net.state_dict(),ckpt_path)\n",
    "        print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "             arr_scores[best_score_idx]),file=sys.stderr)\n",
    "    if len(arr_scores)-best_score_idx>patience:\n",
    "        print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "            monitor,patience),file=sys.stderr)\n",
    "        break \n",
    "    net.load_state_dict(torch.load(ckpt_path))\n",
    "    \n",
    "dfhistory = pd.DataFrame(history)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b038d69",
   "metadata": {},
   "source": [
    "```\n",
    "================================================================================2022-07-10 21:55:18\n",
    "Epoch 1 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 192.16it/s, train_acc=0.664, train_loss=0.646]\n",
    "100%|██████████| 23/23 [00:00<00:00, 252.37it/s, val_acc=0.721, val_loss=0.571]\n",
    "<<<<<< reach best val_acc : 0.7206704020500183 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 21:55:19\n",
    "Epoch 2 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 212.44it/s, train_acc=0.725, train_loss=0.576]\n",
    "100%|██████████| 23/23 [00:00<00:00, 183.68it/s, val_acc=0.726, val_loss=0.503]\n",
    "<<<<<< reach best val_acc : 0.7262569665908813 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 21:55:19\n",
    "Epoch 3 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 128.57it/s, train_acc=0.772, train_loss=0.517]\n",
    "100%|██████████| 23/23 [00:00<00:00, 195.21it/s, val_acc=0.782, val_loss=0.445]\n",
    "<<<<<< reach best val_acc : 0.7821229100227356 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 21:55:20\n",
    "Epoch 4 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 139.91it/s, train_acc=0.784, train_loss=0.495]\n",
    "100%|██████████| 23/23 [00:00<00:00, 281.71it/s, val_acc=0.793, val_loss=0.435]\n",
    "<<<<<< reach best val_acc : 0.7932960987091064 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 21:55:21\n",
    "Epoch 5 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 216.33it/s, train_acc=0.788, train_loss=0.493]\n",
    "100%|██████████| 23/23 [00:00<00:00, 246.54it/s, val_acc=0.81, val_loss=0.409]\n",
    "<<<<<< reach best val_acc : 0.8100558519363403 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 21:55:21\n",
    "Epoch 6 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 191.69it/s, train_acc=0.765, train_loss=0.481]\n",
    "100%|██████████| 23/23 [00:00<00:00, 251.35it/s, val_acc=0.777, val_loss=0.436]\n",
    "\n",
    "================================================================================2022-07-10 21:55:22\n",
    "Epoch 7 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 192.42it/s, train_acc=0.781, train_loss=0.493]\n",
    "100%|██████████| 23/23 [00:00<00:00, 241.61it/s, val_acc=0.771, val_loss=0.462]\n",
    "\n",
    "================================================================================2022-07-10 21:55:22\n",
    "Epoch 8 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 211.52it/s, train_acc=0.801, train_loss=0.475]\n",
    "100%|██████████| 23/23 [00:00<00:00, 263.07it/s, val_acc=0.793, val_loss=0.406]\n",
    "\n",
    "================================================================================2022-07-10 21:55:23\n",
    "Epoch 9 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 199.20it/s, train_acc=0.798, train_loss=0.444]\n",
    "100%|██████████| 23/23 [00:00<00:00, 265.92it/s, val_acc=0.782, val_loss=0.43]\n",
    "\n",
    "================================================================================2022-07-10 21:55:23\n",
    "Epoch 10 / 20\n",
    "\n",
    "100%|██████████| 89/89 [00:00<00:00, 193.12it/s, train_acc=0.81, train_loss=0.445] \n",
    "100%|██████████| 23/23 [00:00<00:00, 259.94it/s, val_acc=0.771, val_loss=0.506]\n",
    "<<<<<< val_acc without improvement in 5 epoch, early stopping >>>>>>\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c5ba3ad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec78a930",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "70f7f3b3",
   "metadata": {},
   "source": [
    "### 四，评估模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cefbce96",
   "metadata": {},
   "source": [
    "我们首先评估一下模型在训练集和验证集上的效果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a731173",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841b8193",
   "metadata": {},
   "source": [
    "![](./data/1-1-dfhistory.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10ab56d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b47c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "077a685c",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h426f4kjqfj20fy0a9q3a.jpg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bdfcf5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"acc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9263b7b",
   "metadata": {},
   "source": [
    "![](https://tva1.sinaimg.cn/large/e6c9d24egy1h426dvo2upj20fy0a9t92.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1f1cb9e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17d7e80c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "05c4494b",
   "metadata": {},
   "source": [
    "### 五，使用模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da16f9ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "#预测概率\n",
    "\n",
    "y_pred_probs = torch.sigmoid(net(torch.tensor(x_test[0:10]).float())).data\n",
    "y_pred_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bde97226",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.1146],\n",
    "        [0.6517],\n",
    "        [0.4307],\n",
    "        [0.8692],\n",
    "        [0.5542],\n",
    "        [0.7894],\n",
    "        [0.1096],\n",
    "        [0.7125],\n",
    "        [0.6027],\n",
    "        [0.1139]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a31cb281",
   "metadata": {},
   "outputs": [],
   "source": [
    "#预测类别\n",
    "y_pred = torch.where(y_pred_probs>0.5,\n",
    "        torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "574bdfb5",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.],\n",
    "        [1.],\n",
    "        [0.],\n",
    "        [1.],\n",
    "        [1.],\n",
    "        [1.],\n",
    "        [0.],\n",
    "        [1.],\n",
    "        [1.],\n",
    "        [0.]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48f6cacc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "57fe8bba",
   "metadata": {},
   "source": [
    "### 六，保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89eb7506",
   "metadata": {},
   "source": [
    "Pytorch 有两种保存模型的方式，都是通过调用pickle序列化方法实现的。\n",
    "\n",
    "第一种方法只保存模型参数。\n",
    "\n",
    "第二种方法保存完整模型。\n",
    "\n",
    "推荐使用第一种，第二种方法可能在切换设备和目录的时候出现各种问题。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9113eb43",
   "metadata": {},
   "source": [
    "**1，保存模型参数(推荐)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6098000",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(net.state_dict().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb2c5f22",
   "metadata": {},
   "source": [
    "```\n",
    "odict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cfa68ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 保存模型参数\n",
    "\n",
    "torch.save(net.state_dict(), \"./data/net_parameter.pt\")\n",
    "\n",
    "net_clone = create_net()\n",
    "net_clone.load_state_dict(torch.load(\"./data/net_parameter.pt\"))\n",
    "\n",
    "torch.sigmoid(net_clone.forward(torch.tensor(x_test[0:10]).float())).data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47260e15",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.1146],\n",
    "        [0.6517],\n",
    "        [0.4307],\n",
    "        [0.8692],\n",
    "        [0.5542],\n",
    "        [0.7894],\n",
    "        [0.1096],\n",
    "        [0.7125],\n",
    "        [0.6027],\n",
    "        [0.1139]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fcfbadf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "bee51ce5",
   "metadata": {},
   "source": [
    "**2，保存完整模型(不推荐)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c969c33",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "torch.save(net, './data/net_model.pt')\n",
    "net_loaded = torch.load('./data/net_model.pt')\n",
    "torch.sigmoid(net_loaded(torch.tensor(x_test[0:10]).float())).data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbb6ab5e",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.1146],\n",
    "        [0.6517],\n",
    "        [0.4307],\n",
    "        [0.8692],\n",
    "        [0.5542],\n",
    "        [0.7894],\n",
    "        [0.1096],\n",
    "        [0.7125],\n",
    "        [0.6027],\n",
    "        [0.1139]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f7a2e7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "52eacb75",
   "metadata": {},
   "source": [
    "**如果本书对你有所帮助，想鼓励一下作者，记得给本项目加一颗星星star⭐️，并分享给你的朋友们喔😊!** \n",
    "\n",
    "如果对本书内容理解上有需要进一步和作者交流的地方，欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限，会酌情予以回复。\n",
    "\n",
    "也可以在公众号后台回复关键字：**加群**，加入读者交流群和大家讨论。\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bdbcc33",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
